{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "38f7adeb-116f-4081-9e0a-2c44aef3cc95",
   "metadata": {},
   "source": [
    "- references\n",
    "    - https://github.com/kuangliu/pytorch-cifar\n",
    "    - https://github.com/Siddharth1698/ResNet-on-CIFAR-10-dataset.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d8b10081-0d65-4998-b7ae-7c3d024bf1ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.models import resnet18, resnet34, resnet101, resnet152\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c08660e1-4a0a-4b83-b9ea-5f4ce6fa4ef5",
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = ('plane', 'car', 'bird', 'cat', 'deer',\n",
    "           'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "280216d8-3a97-4127-b8f1-5f0fef5a9f71",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model):\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    # for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):\n",
    "    for batch_idx, (inputs, targets) in enumerate(train_loader):\n",
    "        inputs, targets = inputs.to(device), targets.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.item()\n",
    "        _, predicted = outputs.max(1)\n",
    "        total += targets.size(0)\n",
    "        correct += predicted.eq(targets).sum().item()\n",
    "\n",
    "    # print(f'train loss: {train_loss/len(train_loader):.3f} | acc: {100.*correct/total:.3f}({correct}/{total})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "956fc617-02f4-4dfb-ac23-aec6a029712e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model):\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (inputs, targets) in tqdm(enumerate(test_loader), total=len(test_loader)):\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, targets)\n",
    "\n",
    "            test_loss += loss.item()\n",
    "            _, predicted = outputs.max(1)\n",
    "            total += targets.size(0)\n",
    "            correct += predicted.eq(targets).sum().item()\n",
    "\n",
    "    print(f'test loss: {test_loss/len(test_loader):.3f} | acc: {100.*correct/total:.3f}%({correct}/{total})')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f88c842b-8741-4476-832d-1a75ac8eba5e",
   "metadata": {},
   "source": [
    "## training from scartch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "486a328e-5d1a-4c8b-b010-a9737bc199da",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform_train = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "99b34e93-6a53-4215-85ee-d1dd0db555a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "train_set = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=True, download=True, transform=transform_train)\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_set, batch_size=1024, shuffle=True, num_workers=2)\n",
    "\n",
    "test_set = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=False, download=True, transform=transform_test)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_set, batch_size=250, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "497b6b4e-9eb9-4960-bd45-b54cb5dfbdec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "48.828125"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_set) / train_loader.batch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "3e9d0e10-3e54-4c81-9c12-8234be02e7a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "35ba8ba265464fabaad4cabab5c20591",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "967ca50b3e584c16abe35f0b43aa6f0e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test loss: 1.066 | acc: 63.140%(6314/10000)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c76bae94b53461396e51d305c69f1ca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test loss: 0.833 | acc: 73.120%(7312/10000)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4dccba5201b40d485dca6e0345e52b8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test loss: 1.133 | acc: 73.570%(7357/10000)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "69f1eafae28a4a04a4fe3bba0922a2c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test loss: 1.591 | acc: 73.710%(7371/10000)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "91e6b706ebb34c0eb48cc4be3df38a89",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test loss: 1.909 | acc: 75.220%(7522/10000)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "feb9548d1e93450daa44c548652f6cf5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test loss: 2.126 | acc: 75.790%(7579/10000)\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[75], line 19\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[38;5;66;03m# model.to(device)\u001b[39;00m\n\u001b[1;32m     18\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(max_epochs)):\n\u001b[0;32m---> 19\u001b[0m     \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     20\u001b[0m     scheduler\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m     21\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m (epoch\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m5\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "Cell \u001b[0;32mIn[56], line 13\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model)\u001b[0m\n\u001b[1;32m     11\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(outputs, targets)\n\u001b[1;32m     12\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m---> 13\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     15\u001b[0m train_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m     16\u001b[0m _, predicted \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;241m1\u001b[39m)\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:69\u001b[0m, in \u001b[0;36mLRScheduler.__init__.<locals>.with_counter.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     67\u001b[0m instance\u001b[38;5;241m.\u001b[39m_step_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m     68\u001b[0m wrapped \u001b[38;5;241m=\u001b[39m func\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__get__\u001b[39m(instance, \u001b[38;5;28mcls\u001b[39m)\n\u001b[0;32m---> 69\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mwrapped\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/optim/optimizer.py:280\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    276\u001b[0m         \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    277\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs),\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    278\u001b[0m                                \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbut got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 280\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    281\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m    283\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/optim/optimizer.py:33\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m     32\u001b[0m     torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[0;32m---> 33\u001b[0m     ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     34\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m     35\u001b[0m     torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(prev_grad)\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/optim/adam.py:141\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m    130\u001b[0m     beta1, beta2 \u001b[38;5;241m=\u001b[39m group[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbetas\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m    132\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_group(\n\u001b[1;32m    133\u001b[0m         group,\n\u001b[1;32m    134\u001b[0m         params_with_grad,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    138\u001b[0m         max_exp_avg_sqs,\n\u001b[1;32m    139\u001b[0m         state_steps)\n\u001b[0;32m--> 141\u001b[0m     \u001b[43madam\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    142\u001b[0m \u001b[43m        \u001b[49m\u001b[43mparams_with_grad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    143\u001b[0m \u001b[43m        \u001b[49m\u001b[43mgrads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    144\u001b[0m \u001b[43m        \u001b[49m\u001b[43mexp_avgs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    145\u001b[0m \u001b[43m        \u001b[49m\u001b[43mexp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    146\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmax_exp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    147\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstate_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    148\u001b[0m \u001b[43m        \u001b[49m\u001b[43mamsgrad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mamsgrad\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    149\u001b[0m \u001b[43m        \u001b[49m\u001b[43mbeta1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    150\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mbeta2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    151\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mlr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    152\u001b[0m \u001b[43m        \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mweight_decay\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    153\u001b[0m \u001b[43m        \u001b[49m\u001b[43meps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43meps\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    154\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmaximize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmaximize\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    155\u001b[0m \u001b[43m        \u001b[49m\u001b[43mforeach\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mforeach\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    156\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mcapturable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcapturable\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    157\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdifferentiable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mdifferentiable\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    158\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfused\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mfused\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    159\u001b[0m \u001b[43m        \u001b[49m\u001b[43mgrad_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgrad_scale\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    160\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfound_inf\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfound_inf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    161\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    163\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/optim/adam.py:281\u001b[0m, in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach, capturable, differentiable, fused, grad_scale, found_inf, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)\u001b[0m\n\u001b[1;32m    278\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    279\u001b[0m     func \u001b[38;5;241m=\u001b[39m _single_tensor_adam\n\u001b[0;32m--> 281\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    282\u001b[0m \u001b[43m     \u001b[49m\u001b[43mgrads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    283\u001b[0m \u001b[43m     \u001b[49m\u001b[43mexp_avgs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    284\u001b[0m \u001b[43m     \u001b[49m\u001b[43mexp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    285\u001b[0m \u001b[43m     \u001b[49m\u001b[43mmax_exp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    286\u001b[0m \u001b[43m     \u001b[49m\u001b[43mstate_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    287\u001b[0m \u001b[43m     \u001b[49m\u001b[43mamsgrad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mamsgrad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    288\u001b[0m \u001b[43m     \u001b[49m\u001b[43mbeta1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    289\u001b[0m \u001b[43m     \u001b[49m\u001b[43mbeta2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    290\u001b[0m \u001b[43m     \u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    291\u001b[0m \u001b[43m     \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweight_decay\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    292\u001b[0m \u001b[43m     
\u001b[49m\u001b[43meps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43meps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    293\u001b[0m \u001b[43m     \u001b[49m\u001b[43mmaximize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmaximize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    294\u001b[0m \u001b[43m     \u001b[49m\u001b[43mcapturable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcapturable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    295\u001b[0m \u001b[43m     \u001b[49m\u001b[43mdifferentiable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdifferentiable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    296\u001b[0m \u001b[43m     \u001b[49m\u001b[43mgrad_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrad_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    297\u001b[0m \u001b[43m     \u001b[49m\u001b[43mfound_inf\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfound_inf\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/optim/adam.py:505\u001b[0m, in \u001b[0;36m_multi_tensor_adam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, grad_scale, found_inf, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize, capturable, differentiable)\u001b[0m\n\u001b[1;32m    503\u001b[0m     denom \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39m_foreach_add(max_exp_avg_sq_sqrt, eps)\n\u001b[1;32m    504\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 505\u001b[0m     exp_avg_sq_sqrt \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_foreach_sqrt\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice_exp_avg_sqs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    506\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)\n\u001b[1;32m    507\u001b[0m     denom \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39m_foreach_add(exp_avg_sq_sqrt, eps)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# lr = 0.1\n",
    "# lr = 0.01, 74%\n",
    "model = resnet18(weights=None, num_classes=len(classes)).to(device) \n",
    "# lr = 0.01, 74%\n",
    "# lr = 0.005, 76% \n",
    "# model = resnet34(weights=None, num_classes=len(classes)).to(device) \n",
    "# lr = 0.01, 68%\n",
    "# lr = 0.005, 66%\n",
    "# model = resnet101(weights=None, num_classes=len(classes)).to(device) \n",
    "# lr = 0.01, \n",
    "# lr = 0.005, \n",
    "# model = resnet152(weights=None, num_classes=len(classes)).to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-2)\n",
    "\n",
    "max_epochs = 50\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)\n",
    "# model.to(device)\n",
    "for epoch in tqdm(range(max_epochs)):\n",
    "    train(model)\n",
    "    scheduler.step()\n",
    "    if (epoch+1) % 5 == 0:\n",
    "        test(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27f96e00-98e7-48c6-92a0-7c11a2499206",
   "metadata": {},
   "source": [
    "## training from pretrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "895cd833-0774-4fe4-b86f-2ded6f830a1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform_train = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "\n",
    "train_set = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=True, download=True, transform=transform_train)\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_set, batch_size=1024, shuffle=True, num_workers=2)\n",
    "\n",
    "test_set = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=False, download=True, transform=transform_test)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_set, batch_size=250, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c4179c2-40ba-4035-a0ef-8bb4b356da54",
   "metadata": {},
   "source": [
    "### Vanilla FC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "b8df6148-3ce9-4495-ac52-278c8112f6b1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: \"https://download.pytorch.org/models/resnet34-b627a593.pth\" to /home/whaow/.cache/torch/hub/checkpoints/resnet34-b627a593.pth\n",
      "100%|██████████| 83.3M/83.3M [00:01<00:00, 64.0MB/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "28bf4c5baa2843a5affc58e4830e98b0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f0e4d08fe1324063809fb239c46712f7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test loss: 0.633 | acc: 82.750%(8275/10000)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "58188954ad2349fc883eeb38e3801488",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test loss: 0.777 | acc: 82.610%(8261/10000)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "233e1c5a734b4bd6a69b4837bd2e18a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test loss: 0.977 | acc: 83.030%(8303/10000)\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[74], line 15\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[38;5;66;03m# model.to(device)\u001b[39;00m\n\u001b[1;32m     14\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(max_epochs)):\n\u001b[0;32m---> 15\u001b[0m     \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     16\u001b[0m     scheduler\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m     17\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m (epoch\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m5\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "Cell \u001b[0;32mIn[56], line 7\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model)\u001b[0m\n\u001b[1;32m      5\u001b[0m total \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;66;03m# for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):\u001b[39;00m\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_idx, (inputs, targets) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(train_loader):\n\u001b[1;32m      8\u001b[0m     inputs, targets \u001b[38;5;241m=\u001b[39m inputs\u001b[38;5;241m.\u001b[39mto(device), targets\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m      9\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py:633\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    630\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    631\u001b[0m     \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m    632\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()  \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 633\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m    635\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    636\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    637\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1328\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1325\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_data(data)\n\u001b[1;32m   1327\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_shutdown \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m-> 1328\u001b[0m idx, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1329\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m   1330\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable:\n\u001b[1;32m   1331\u001b[0m     \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1294\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1290\u001b[0m     \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001b[39;00m\n\u001b[1;32m   1291\u001b[0m     \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[1;32m   1292\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1293\u001b[0m     \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1294\u001b[0m         success, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_try_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1295\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[1;32m   1296\u001b[0m             \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1132\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m   1119\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_try_get_data\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout\u001b[38;5;241m=\u001b[39m_utils\u001b[38;5;241m.\u001b[39mMP_STATUS_CHECK_INTERVAL):\n\u001b[1;32m   1120\u001b[0m     \u001b[38;5;66;03m# Tries to fetch data from `self._data_queue` once for a given timeout.\u001b[39;00m\n\u001b[1;32m   1121\u001b[0m     \u001b[38;5;66;03m# This can also be used as inner loop of fetching without timeout, with\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1129\u001b[0m     \u001b[38;5;66;03m# Returns a 2-tuple:\u001b[39;00m\n\u001b[1;32m   1130\u001b[0m     \u001b[38;5;66;03m#   (bool: whether successfully get data, any: data if successful else None)\u001b[39;00m\n\u001b[1;32m   1131\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1132\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_queue\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1133\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n\u001b[1;32m   1134\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m   1135\u001b[0m         \u001b[38;5;66;03m# At timeout and error, we manually check whether any worker has\u001b[39;00m\n\u001b[1;32m   1136\u001b[0m         \u001b[38;5;66;03m# failed. 
Note that this is the only mechanism for Windows to detect\u001b[39;00m\n\u001b[1;32m   1137\u001b[0m         \u001b[38;5;66;03m# worker failures.\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/multiprocessing/queues.py:113\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m    111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m block:\n\u001b[1;32m    112\u001b[0m     timeout \u001b[38;5;241m=\u001b[39m deadline \u001b[38;5;241m-\u001b[39m time\u001b[38;5;241m.\u001b[39mmonotonic()\n\u001b[0;32m--> 113\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_poll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m    114\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m Empty\n\u001b[1;32m    115\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_poll():\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/multiprocessing/connection.py:257\u001b[0m, in \u001b[0;36m_ConnectionBase.poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    255\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_closed()\n\u001b[1;32m    256\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_readable()\n\u001b[0;32m--> 257\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_poll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/multiprocessing/connection.py:424\u001b[0m, in \u001b[0;36mConnection._poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    423\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_poll\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout):\n\u001b[0;32m--> 424\u001b[0m     r \u001b[38;5;241m=\u001b[39m \u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    425\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mbool\u001b[39m(r)\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/multiprocessing/connection.py:931\u001b[0m, in \u001b[0;36mwait\u001b[0;34m(object_list, timeout)\u001b[0m\n\u001b[1;32m    928\u001b[0m     deadline \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mmonotonic() \u001b[38;5;241m+\u001b[39m timeout\n\u001b[1;32m    930\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m--> 931\u001b[0m     ready \u001b[38;5;241m=\u001b[39m \u001b[43mselector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    932\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m ready:\n\u001b[1;32m    933\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m [key\u001b[38;5;241m.\u001b[39mfileobj \u001b[38;5;28;01mfor\u001b[39;00m (key, events) \u001b[38;5;129;01min\u001b[39;00m ready]\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/selectors.py:416\u001b[0m, in \u001b[0;36m_PollLikeSelector.select\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    414\u001b[0m ready \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m    415\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 416\u001b[0m     fd_event_list \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_selector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpoll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    417\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mInterruptedError\u001b[39;00m:\n\u001b[1;32m    418\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m ready\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Fine-tune a pretrained torchvision ResNet on the 10 CIFAR-10 classes.\n",
    "# Backbones tried so far (uncomment one line to switch):\n",
    "#   resnet18, lr = 1e-3 -> 84% (noted from an earlier run)\n",
    "#   resnet34, lr = 1e-3 -> result not recorded yet\n",
    "# model = resnet18(weights='DEFAULT')\n",
    "# model = resnet152(weights='DEFAULT')\n",
    "model = resnet34(weights='DEFAULT')\n",
    "\n",
    "# Replace the final fully connected layer with a 10-output head of matching input width.\n",
    "in_features = model.fc.in_features\n",
    "model.fc = nn.Linear(in_features=in_features, out_features=10)\n",
    "model = model.to(device)\n",
    "\n",
    "max_epochs = 50\n",
    "# `train` reads the notebook-level name `optimizer`, so it must keep exactly this name.\n",
    "optimizer = optim.Adam(params=model.parameters(), lr=1e-3)\n",
    "# Cosine learning-rate schedule spanning the whole run; stepped once per epoch below.\n",
    "scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=max_epochs)\n",
    "\n",
    "for epoch in tqdm(range(max_epochs)):\n",
    "    train(model)\n",
    "    scheduler.step()\n",
    "    # `epoch` is zero-based: evaluate after the 5th, 10th, 15th, ... epoch.\n",
    "    if epoch % 5 == 4:\n",
    "        test(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d811d411-5575-494a-9133-0a5d05cae1c9",
   "metadata": {},
   "source": [
    "### Scale the images to 224x224"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c0068a30-5e0c-4e66-8919-c7b4c55d9cce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Training pipeline: resize to 224 px, random horizontal flip (p = 0.4),\n",
    "# random rotation within +/-30 degrees, then tensor conversion and\n",
    "# per-channel normalisation with the standard ImageNet mean/std.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.Resize(size=224),\n",
    "    transforms.RandomHorizontalFlip(p=0.4),\n",
    "    transforms.RandomRotation(degrees=30),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# Evaluation pipeline: same resize and normalisation, no augmentation.\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.Resize(size=224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# `train_loader` / `test_loader` are the notebook-level names the training and\n",
    "# evaluation helpers iterate over, so they keep exactly these names.\n",
    "train_set = torchvision.datasets.CIFAR10(\n",
    "    root='./data',\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transform_train,\n",
    ")\n",
    "train_loader = DataLoader(\n",
    "    dataset=train_set,\n",
    "    batch_size=512,\n",
    "    shuffle=True,\n",
    "    num_workers=2,\n",
    ")\n",
    "\n",
    "test_set = torchvision.datasets.CIFAR10(\n",
    "    root='./data',\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=transform_test,\n",
    ")\n",
    "test_loader = DataLoader(\n",
    "    dataset=test_set,\n",
    "    batch_size=250,\n",
    "    shuffle=False,\n",
    "    num_workers=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0e2a0d24-42b2-49c2-95d3-427985f787d2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e69d89aeb96b471abbdf75c9f5c0d05d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89d120510a804d0592d08d865ba38aeb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test loss: 0.341 | acc: 88.670%(8867/10000)\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 16\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[38;5;66;03m# model.to(device)\u001b[39;00m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(max_epochs)):\n\u001b[0;32m---> 16\u001b[0m     \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     17\u001b[0m     scheduler\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m     18\u001b[0m     test(model)\n",
      "Cell \u001b[0;32mIn[3], line 15\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model)\u001b[0m\n\u001b[1;32m     12\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m     13\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[0;32m---> 15\u001b[0m train_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     16\u001b[0m _, predicted \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     17\u001b[0m total \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m targets\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m0\u001b[39m)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Fine-tune a pretrained torchvision ResNet on the 10 CIFAR-10 classes,\n",
    "# using whatever `train_loader` / `test_loader` are currently defined in the notebook.\n",
    "# Backbones tried so far (uncomment one line to switch):\n",
    "#   resnet18, lr = 1e-3 -> 84% (noted from an earlier run)\n",
    "#   resnet34, lr = 1e-3 -> result not recorded yet\n",
    "# model = resnet18(weights='DEFAULT')\n",
    "# model = resnet152(weights='DEFAULT')\n",
    "model = resnet34(weights='DEFAULT')\n",
    "\n",
    "# Replace the final fully connected layer with a 10-output head of matching input width.\n",
    "in_features = model.fc.in_features\n",
    "model.fc = nn.Linear(in_features=in_features, out_features=10)\n",
    "model = model.to(device)\n",
    "\n",
    "max_epochs = 5\n",
    "# `train` reads the notebook-level name `optimizer`, so it must keep exactly this name.\n",
    "optimizer = optim.Adam(params=model.parameters(), lr=1e-3)\n",
    "# Cosine learning-rate schedule spanning the whole run; stepped once per epoch below.\n",
    "scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=max_epochs)\n",
    "\n",
    "# Short run: train one epoch, advance the schedule, then evaluate every epoch.\n",
    "for epoch in tqdm(range(max_epochs)):\n",
    "    train(model)\n",
    "    scheduler.step()\n",
    "    test(model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
